{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lmfit example using the ``brute`` method (a.k.a. grid search)\n",
    "\n",
    "\n",
    "This notebook shows a simple example of using [``lmfit.minimizer.Minimizer.brute``](http://lmfit.github.io/lmfit-py/fitting.html#lmfit.minimizer.Minimizer.brute), which uses the method of the same name from [``scipy.optimize``](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brute.html#scipy.optimize.brute). The method computes the function’s value at each point of a multidimensional grid to find the global minimum of the function. It behaves identically to ``scipy.optimize.brute`` when finite bounds are given on all varying parameters, but it can also handle parameters without finite bounds (see below)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preliminary imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "\n",
    "from lmfit import minimize, fit_report, Minimizer, Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the [example](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brute.html) given in the documentation of SciPy: \"We illustrate the use of brute to seek the global minimum of a function of two variables that is given as the sum of a positive-definite quadratic and two deep “Gaussian-shaped” craters. Specifically, define the objective function f as the sum of three other functions, ``f = f1 + f2 + f3``. We suppose each of these has a signature ``(z, *params)``, where ``z = (x, y)``, and params and the functions are as defined below.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we create a set of Parameters where all variables except ``x`` and ``y`` are given fixed values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = Parameters()\n",
    "params.add_many(\n",
    "        ('a', 2, False),\n",
    "        ('b', 3, False),\n",
    "        ('c', 7, False),\n",
    "        ('d', 8, False),\n",
    "        ('e', 9, False),\n",
    "        ('f', 10, False),\n",
    "        ('g', 44, False),\n",
    "        ('h', -1, False),\n",
    "        ('i', 2, False),\n",
    "        ('j', 26, False),\n",
    "        ('k', 1, False),\n",
    "        ('l', -2, False),\n",
    "        ('scale', 0.5, False),\n",
    "        ('x', 0.0, True),\n",
    "        ('y', 0.0, True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Second, create the three functions and the objective function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(p):\n",
    "    par = p.valuesdict()\n",
    "    return (par['a'] * par['x']**2 + par['b'] * par['x'] * par['y'] +\n",
    "            par['c'] * par['y']**2 + par['d']*par['x'] + par['e']*par['y'] + par['f'])\n",
    "\n",
    "\n",
    "def f2(p):\n",
    "    par = p.valuesdict()\n",
    "    return (-1.0*par['g']*np.exp(-((par['x']-par['h'])**2 +\n",
    "            (par['y']-par['i'])**2) / par['scale']))\n",
    "\n",
    "\n",
    "def f3(p):\n",
    "    par = p.valuesdict()\n",
    "    return (-1.0*par['j']*np.exp(-((par['x']-par['k'])**2 +\n",
    "            (par['y']-par['l'])**2) / par['scale']))\n",
    "\n",
    "\n",
    "def f(params):\n",
    "    return f1(params) + f2(params) + f3(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just as in the SciPy documentation, we will do a grid search between ``-4`` and ``4`` with a step size of ``0.25``. The bounds are set as usual with the ``min`` and ``max`` attributes, and the step size is set using ``brute_step``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params['x'].set(min=-4, max=4, brute_step=0.25)\n",
    "params['y'].set(min=-4, max=4, brute_step=0.25)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Performing a grid search is done with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fitter = Minimizer(f, params)\n",
    "result = fitter.minimize(method='brute')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This steps ``x`` and ``y`` from ``-4`` up to ``4`` (not inclusive) in increments of ``0.25``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_x, grid_y = [np.unique(par.ravel()) for par in result.brute_grid]\n",
    "print(grid_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The objective function is evaluated on this grid, and the raw output from ``scipy.optimize.brute`` is stored in the MinimizerResult as ``brute_<parname>`` attributes. These attributes are: \n",
    "\n",
    "``result.brute_x0`` -- A 1-D array containing the coordinates of a point at which the objective function had its minimum value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.brute_x0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "``result.brute_fval`` -- Function value at the point x0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.brute_fval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "``result.brute_grid`` -- Representation of the evaluation grid. It has the same length as x0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.brute_grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "``result.brute_Jout`` -- Function values at each point of the evaluation grid, i.e., Jout = func(*grid)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.brute_Jout"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Reassuringly, the obtained results are identical to using the method in SciPy directly!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: fit a decaying sine wave"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we will explain some of the options of the ``brute`` method.\n",
    "\n",
    "We start off by generating some synthetic data with noise for a decaying sine wave, define an objective function and create a Parameter set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create data to be fitted\n",
    "x = np.linspace(0, 15, 301)\n",
    "np.random.seed(7)\n",
    "data = (5. * np.sin(2*x - 0.1) * np.exp(-x*x*0.025) + \n",
    "        np.random.normal(size=len(x), scale=0.2))\n",
    "plt.plot(x, data, 'b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define objective function: returns the array to be minimized\n",
    "def fcn2min(params, x, data):\n",
    "    \"\"\"Model decaying sine wave, subtract data.\"\"\"\n",
    "    amp = params['amp']\n",
    "    shift = params['shift']\n",
    "    omega = params['omega']\n",
    "    decay = params['decay']\n",
    "    model = amp * np.sin(x*omega + shift) * np.exp(-x*x*decay)\n",
    "    return model - data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a set of Parameters\n",
    "params = Parameters()\n",
    "params.add('amp', value=7, min=2.5)\n",
    "params.add('decay', value=0.05)\n",
    "params.add('shift', value=0.0, min=-np.pi/2., max=np.pi/2)\n",
    "params.add('omega', value=3, max=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast to the implementation in SciPy (as shown in the first example), varying parameters do not need to have finite bounds in lmfit. However, in that case they **do** need the ``brute_step`` attribute specified, so let's do that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params['amp'].set(brute_step=0.25)\n",
    "params['decay'].set(brute_step=0.005)\n",
    "params['omega'].set(brute_step=0.25)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our initial parameter set is now defined as shown below, and this determines how the grid is set up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params.pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we initialize a Minimizer and perform the grid search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fitter = Minimizer(fcn2min, params, fcn_args=(x, data))\n",
    "result_brute = fitter.minimize(method='brute', Ns=25, keep=25)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We used two new parameters here: ``Ns`` and ``keep``. The parameter ``Ns`` determines the 'number of grid points along the axes', similar to its usage in SciPy. Together with ``brute_step``, ``min`` and ``max`` for a Parameter, it dictates how the grid is set up:\n",
    "\n",
    "**(1)** finite bounds are specified (\"SciPy implementation\"): uses ``brute_step`` if present (as in the first example above) or uses ``Ns`` to generate the grid. The latter scenario, which interpolates ``Ns`` points from ``min`` to ``max`` (inclusive), is shown here for the parameter ``shift``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "par_name = 'shift'\n",
    "indx_shift = result_brute.var_names.index(par_name)\n",
    "grid_shift = np.unique(result_brute.brute_grid[indx_shift].ravel())\n",
    "print(\"parameter = {}: \\nnumber of steps = {}\\ngrid = {}\".format(par_name, len(grid_shift), grid_shift))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If finite bounds are not set for a certain parameter, then the user **must** specify ``brute_step``. Three more scenarios are considered here:\n",
    "\n",
    "**(2)** lower bound (min) and brute_step are specified: range = (min, min + Ns * brute_step, brute_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "par_name = 'amp'\n",
    "indx_shift = result_brute.var_names.index(par_name)\n",
    "grid_shift = np.unique(result_brute.brute_grid[indx_shift].ravel())\n",
    "print(\"parameter = {}: \\nnumber of steps = {}\\ngrid = {}\".format(par_name, len(grid_shift), grid_shift))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(3)** upper bound (max) and brute_step are specified: range = (max - Ns * brute_step, max, brute_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "par_name = 'omega'\n",
    "indx_shift = result_brute.var_names.index(par_name)\n",
    "grid_shift = np.unique(result_brute.brute_grid[indx_shift].ravel())\n",
    "print(\"parameter = {}: \\nnumber of steps = {}\\ngrid = {}\".format(par_name, len(grid_shift), grid_shift))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(4)** numerical value (value) and brute_step are specified: range = (value - (Ns//2) * brute_step, value + (Ns//2) * brute_step, brute_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "par_name = 'decay'\n",
    "indx_shift = result_brute.var_names.index(par_name)\n",
    "grid_shift = np.unique(result_brute.brute_grid[indx_shift].ravel())\n",
    "print(\"parameter = {}: \\nnumber of steps = {}\\ngrid = {}\".format(par_name, len(grid_shift), grid_shift))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [``MinimizerResult``](http://lmfit.github.io/lmfit-py/fitting.html#lmfit.minimizer.MinimizerResult) contains all the usual best-fit parameters and fitting statistics. For example, the optimal solution from the grid search is given below together with a plot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fit_report(result_brute))\n",
    "plt.plot(x, data, 'b')\n",
    "plt.plot(x, data + fcn2min(result_brute.params, x, data), 'r--')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that this fit is already very good, which is what we should expect since our ``brute`` force grid is sampled rather finely and encompasses the \"correct\" values.\n",
    "\n",
    "In a more realistic, complicated example, the ``brute`` method would be used to obtain reasonable values for the parameters, which then serve as starting values for another minimization (e.g., using ``leastsq``). That is where the ``keep`` parameter comes into play: it determines the \"number of best candidates from the brute force method that are stored in the ``candidates`` attribute\". In the example above we store the 25 best-ranking solutions (the default value is ``50``, and all grid points can be stored by choosing ``'all'``). The ``candidates`` attribute contains the parameters and ``chisqr`` from the brute force method as a namedtuple, ``('Candidate', ['params', 'score'])``, sorted by ``chisqr`` in ascending order. To access the values for a particular candidate, use ``result.candidates[#].params`` or ``result.candidates[#].score``, where a lower # represents a better candidate. The ``show_candidates(#)`` method uses ``pretty_print()`` to show a specific candidate #, or all candidates when no number is specified.\n",
    "\n",
    "The optimal fit is, as usual, stored in the ``MinimizerResult.params`` attribute and is, therefore, identical to ``result_brute.show_candidates(0)``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_brute.show_candidates(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, the next-best candidate already has a noticeably higher ``chisqr``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_brute.show_candidates(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That candidate is therefore a less likely solution. However, as mentioned above, in most cases you will want to run another minimization using the solutions from the ``brute`` method as starting values. This is easily done as shown in the code below, where we perform a ``leastsq`` minimization starting from each of the top 25 solutions and accept a solution if its ``chisqr`` is lower than that of the previously 'optimal' solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_result = copy.deepcopy(result_brute)\n",
    "\n",
    "for candidate in result_brute.candidates:\n",
    "    trial = fitter.minimize(method='leastsq', params=candidate.params)\n",
    "    if trial.chisqr < best_result.chisqr:\n",
    "        best_result = trial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the ``leastsq`` minimization we obtain the following parameters for the best result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fit_report(best_result))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the parameters have not changed significantly, as they were already very close to the \"real\" values; this can also be seen in the plot below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(x, data, 'b')\n",
    "plt.plot(x, data + fcn2min(result_brute.params, x, data), 'r--', label='brute')\n",
    "plt.plot(x, data + fcn2min(best_result.params, x, data), 'g--', label='brute followed by leastsq')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, the results from the ``brute`` force grid search can be visualized using the rather lengthy Python function below (which might get incorporated in lmfit at some point)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_results_brute(result, best_vals=True, varlabels=None,\n",
    "                       output=None):\n",
    "    \"\"\"Visualize the result of the brute force grid search.\n",
    "\n",
    "    The figure displays the chi-square value per parameter and contour\n",
    "    plots for all combinations of two parameters.\n",
    "\n",
    "    Inspired by the `corner` package (https://github.com/dfm/corner.py).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    result : :class:`~lmfit.minimizer.MinimizerResult`\n",
    "        Contains the results from the :meth:`brute` method.\n",
    "\n",
    "    best_vals : bool, optional\n",
    "        Whether to show the best values from the grid search (default is True).\n",
    "\n",
    "    varlabels : list, optional\n",
    "        If None (default), use `result.var_names` as axis labels, otherwise\n",
    "        use the names specified in `varlabels`.\n",
    "\n",
    "    output : str, optional\n",
    "        Name of the output PDF file; if None (default), no file is saved.\n",
    "    \"\"\"\n",
    "    from matplotlib.colors import LogNorm\n",
    "\n",
    "    npars = len(result.var_names)\n",
    "    fig, axes = plt.subplots(npars, npars)\n",
    "\n",
    "    if not varlabels:\n",
    "        varlabels = result.var_names\n",
    "    if best_vals and isinstance(best_vals, bool):\n",
    "        best_vals = result.params\n",
    "\n",
    "    for i, par1 in enumerate(result.var_names):\n",
    "        for j, par2 in enumerate(result.var_names):\n",
    "\n",
    "            # parameter vs chi2 in case of only one parameter\n",
    "            if npars == 1:\n",
    "                axes.plot(result.brute_grid, result.brute_Jout, 'o', ms=3)\n",
    "                axes.set_ylabel(r'$\\chi^{2}$')\n",
    "                axes.set_xlabel(varlabels[i])\n",
    "                if best_vals:\n",
    "                    axes.axvline(best_vals[par1].value, ls='dashed', color='r')\n",
    "\n",
    "            # parameter vs chi2 profile on top\n",
    "            elif i == j and j < npars-1:\n",
    "                if i == 0:\n",
    "                    axes[0, 0].axis('off')\n",
    "                ax = axes[i, j+1]\n",
    "                red_axis = tuple([a for a in range(npars) if a != i])\n",
    "                ax.plot(np.unique(result.brute_grid[i]),\n",
    "                        np.minimum.reduce(result.brute_Jout, axis=red_axis),\n",
    "                        'o', ms=3)\n",
    "                ax.set_ylabel(r'$\\chi^{2}$')\n",
    "                ax.yaxis.set_label_position(\"right\")\n",
    "                ax.yaxis.set_ticks_position('right')\n",
    "                ax.set_xticks([])\n",
    "                if best_vals:\n",
    "                    ax.axvline(best_vals[par1].value, ls='dashed', color='r')\n",
    "\n",
    "            # parameter vs chi2 profile on the left\n",
    "            elif j == 0 and i > 0:\n",
    "                ax = axes[i, j]\n",
    "                red_axis = tuple([a for a in range(npars) if a != i])\n",
    "                ax.plot(np.minimum.reduce(result.brute_Jout, axis=red_axis),\n",
    "                        np.unique(result.brute_grid[i]), 'o', ms=3)\n",
    "                ax.invert_xaxis()\n",
    "                ax.set_ylabel(varlabels[i])\n",
    "                if i != npars-1:\n",
    "                    ax.set_xticks([])\n",
    "                elif i == npars-1:\n",
    "                    ax.set_xlabel(r'$\\chi^{2}$')\n",
    "                if best_vals:\n",
    "                    ax.axhline(best_vals[par1].value, ls='dashed', color='r')\n",
    "\n",
    "            # contour plots for all combinations of two parameters\n",
    "            elif j > i:\n",
    "                ax = axes[j, i+1]\n",
    "                red_axis = tuple([a for a in range(npars) if a != i and a != j])\n",
    "                X, Y = np.meshgrid(np.unique(result.brute_grid[i]),\n",
    "                                   np.unique(result.brute_grid[j]))\n",
    "                lvls1 = np.linspace(result.brute_Jout.min(),\n",
    "                                    np.median(result.brute_Jout)/2.0, 7, dtype='int')\n",
    "                lvls2 = np.linspace(np.median(result.brute_Jout)/2.0,\n",
    "                                    np.median(result.brute_Jout), 3, dtype='int')\n",
    "                lvls = np.unique(np.concatenate((lvls1, lvls2)))\n",
    "                ax.contourf(X.T, Y.T, np.minimum.reduce(result.brute_Jout, axis=red_axis),\n",
    "                            lvls, norm=LogNorm())\n",
    "                ax.set_yticks([])\n",
    "                if best_vals:\n",
    "                    ax.axvline(best_vals[par1].value, ls='dashed', color='r')\n",
    "                    ax.axhline(best_vals[par2].value, ls='dashed', color='r')\n",
    "                    ax.plot(best_vals[par1].value, best_vals[par2].value, 'rs', ms=3)\n",
    "                if j != npars-1:\n",
    "                    ax.set_xticks([])\n",
    "                elif j == npars-1:\n",
    "                    ax.set_xlabel(varlabels[i])\n",
    "                if j - i >= 2:\n",
    "                    axes[i, j].axis('off')\n",
    "                    \n",
    "    if output is not None:\n",
    "        plt.savefig(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then generate the figure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results_brute(result_brute, best_vals=True, varlabels=None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
